from pathlib import Path
from threading import Thread
from typing import Any, Optional

from ctransformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Config
from transformers import TextIteratorStreamer, TextStreamer

from bot.client.llm_client import LlmClient, LlmClientType
from bot.model.model import Model


class CtransformersClient(LlmClient):
    """LLM client backed by `ctransformers` with a Hugging Face-compatible interface.

    The model is loaded with ``hf=True`` so it exposes the `transformers`
    ``generate()`` API, which lets this client reuse ``TextStreamer`` /
    ``TextIteratorStreamer`` for token streaming.
    """

    # Streamer handle for the most recently started iterator-streaming generation.
    streamer: Optional[TextIteratorStreamer] = None

    def __init__(self, model_folder: Path, model_settings: Model):
        """
        Initializes the client, validating that the model supports ctransformers.

        Args:
            model_folder (Path): Folder containing the model file.
            model_settings (Model): Settings describing the model to load.

        Raises:
            ValueError: If the model does not list ctransformers among its
                supported clients.
        """
        if LlmClientType.CTRANSFORMERS not in model_settings.clients:
            raise ValueError(
                f"{model_settings.file_name} is not supported "
                f"by the {LlmClientType.CTRANSFORMERS.value} client."
            )
        self.config = Config(**model_settings.config)
        super().__init__(model_folder, model_settings)

    def _load_llm(self) -> Any:
        """
        Loads the language model via ctransformers.

        Returns:
            Any: The loaded model, wrapped with the Hugging Face interface
                (``hf=True``) so it supports ``generate()`` and streamers.
        """
        return AutoModelForCausalLM.from_pretrained(
            model_path_or_repo_id=str(self.model_folder),
            model_file=self.model_settings.file_name,
            model_type=self.model_settings.type,
            config=AutoConfig(config=self.config),
            hf=True,  # expose the transformers-compatible generate() API
        )

    def _load_tokenizer(self) -> Any:
        """
        Builds a tokenizer from the loaded ctransformers model.

        Returns:
            Any: The tokenizer associated with ``self.llm``.
        """
        return AutoTokenizer.from_pretrained(self.llm)

    def _encode_prompt(self, prompt: str):
        """
        Encodes the given prompt using the model's tokenizer.

        Args:
            prompt (str): The input prompt to be encoded.

        Returns:
            torch.Tensor: The input IDs tensor generated by the tokenizer.
        """
        return self.tokenizer(text=prompt, return_tensors="pt").input_ids

    def _decode_answer(self, prompt_ids, answer_ids):
        """
        Decodes the answer IDs tensor into a human-readable answer using the model's tokenizer.

        Only the tokens generated after the prompt are decoded, so the prompt
        itself is not echoed back in the answer.

        Args:
            prompt_ids (torch.Tensor): The prompt IDs tensor used for generating the answer.
            answer_ids (torch.Tensor): The answer IDs tensor generated by the language model.

        Returns:
            str: The decoded answer.
        """
        return self.tokenizer.batch_decode(answer_ids[:, prompt_ids.shape[1] :])[0]

    def generate_answer(self, prompt: str, max_new_tokens: int = 512) -> str:
        """
        Generates an answer based on the given prompt using the language model.

        Args:
            prompt (str): The input prompt for generating the answer.
            max_new_tokens (int): The maximum number of new tokens to generate (default is 512).

        Returns:
            str: The generated answer.
        """
        prompt_ids = self._encode_prompt(prompt)
        answer_ids = self.llm.generate(prompt_ids, max_new_tokens=max_new_tokens)
        return self._decode_answer(prompt_ids, answer_ids)

    async def async_generate_answer(self, prompt: str, max_new_tokens: int = 512) -> str:
        """
        Generates an answer based on the given prompt using the language model.

        NOTE: generation runs synchronously on the calling thread; this
        coroutine exists to satisfy the async client interface and does not
        yield control while generating.

        Args:
            prompt (str): The input prompt for generating the answer.
            max_new_tokens (int): The maximum number of new tokens to generate (default is 512).

        Returns:
            str: The generated answer.
        """
        # Delegate to the synchronous implementation to avoid code duplication.
        return self.generate_answer(prompt, max_new_tokens=max_new_tokens)

    def stream_answer(self, prompt: str, skip_prompt: bool = True, max_new_tokens: int = 512) -> str:
        """
        Generates an answer by streaming tokens to stdout using the TextStreamer.

        Args:
            prompt (str): The input prompt for generating the answer.
            skip_prompt (bool): Whether to skip the prompt tokens during streaming (default is True).
            max_new_tokens (int): The maximum number of new tokens to generate (default is 512).

        Returns:
            str: The generated answer.
        """
        streamer = TextStreamer(tokenizer=self.tokenizer, skip_prompt=skip_prompt)

        prompt_ids = self._encode_prompt(prompt)
        answer_ids = self.llm.generate(prompt_ids, streamer=streamer, max_new_tokens=max_new_tokens)

        return self._decode_answer(prompt_ids, answer_ids)

    def start_answer_iterator_streamer(
        self, prompt: str, skip_prompt: bool = True, max_new_tokens: int = 512
    ) -> TextIteratorStreamer:
        """
        Starts a background thread that generates an answer, streaming tokens
        through a TextIteratorStreamer.

        Args:
            prompt (str): The input prompt for generating the answer.
            skip_prompt (bool): Whether to skip the prompt tokens during streaming (default is True).
            max_new_tokens (int): The maximum number of new tokens to generate (default is 512).

        Returns:
            TextIteratorStreamer: The streamer to iterate over the generated tokens.
        """
        self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=skip_prompt)

        generation_kwargs = dict(
            input_ids=self._encode_prompt(prompt),
            streamer=self.streamer,
            max_new_tokens=max_new_tokens,
        )
        # Run generation in a worker thread so the caller can consume the
        # streamer incrementally while tokens are produced.
        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        thread.start()

        return self.streamer

    async def async_start_answer_iterator_streamer(
        self, prompt: str, skip_prompt: bool = True, max_new_tokens: int = 512
    ) -> TextIteratorStreamer:
        """
        Starts a background thread that generates an answer, streaming tokens
        through a TextIteratorStreamer.

        NOTE: this coroutine exists to satisfy the async client interface; the
        thread setup itself runs synchronously.

        Args:
            prompt (str): The input prompt for generating the answer.
            skip_prompt (bool): Whether to skip the prompt tokens during streaming (default is True).
            max_new_tokens (int): The maximum number of new tokens to generate (default is 512).

        Returns:
            TextIteratorStreamer: The streamer to iterate over the generated tokens.
        """
        # Delegate to the synchronous implementation to avoid code duplication.
        return self.start_answer_iterator_streamer(
            prompt, skip_prompt=skip_prompt, max_new_tokens=max_new_tokens
        )

    def parse_token(self, token):
        """Returns the streamed token unchanged (no post-processing needed here)."""
        return token
